{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e6aa517b",
   "metadata": {},
   "source": [
    "# MindNLP-bigbird_pegasus模型微调\n",
    "基础模型：google/bigbird-pegasus-large-arxiv\n",
    "tokenizer：google/bigbird-pegasus-large-arxiv\n",
    "微调数据集：databricks/databricks-dolly-15k\n",
    "硬件：Ascend910B1\n",
    "环境\n",
    "| Software    | Version                     |\n",
    "| ----------- | --------------------------- |\n",
    "| MindSpore   | MindSpore 2.4.0             |\n",
    "| MindSpore   | MindSpore 0.4.1             |\n",
    "| CANN        | 8.0                         |\n",
    "| Python      | Python 3.9                  |\n",
    "| OS platform | Ubuntu 5.4.0-42-generic     |\n",
    "\n",
    "## instruction\n",
    "BigBird-Pegasus 是基于 BigBird 和 Pegasus 的混合模型，结合了两者的优势，专为处理长文本序列设计。BigBird 是一种基于 Transformer 的模型，通过稀疏注意力机制处理长序列，降低计算复杂度。Pegasus 是专为文本摘要设计的模型，通过自监督预训练任务（GSG）提升摘要生成能力。BigBird-Pegasus 结合了 BigBird 的长序列处理能力和 Pegasus 的摘要生成能力，适用于长文本摘要任务，如学术论文和长文档摘要。\n",
    "Databricks Dolly 15k 是由 Databricks 发布的高质量指令微调数据集，包含约 15,000 条人工生成的指令-响应对，用于训练和评估对话模型。是专门为NLP模型微调设计的数据集。\n",
    "## train loss\n",
    "\n",
    "对比微调训练的loss变化\n",
    "\n",
    "| epoch | mindnlp+mindspore | transformer+torch（4060） |\n",
    "| ----- | ----------------- | ------------------------- |\n",
    "| 1     | 2.0958            | 8.7301                    |\n",
    "| 2     | 1.969             | 8.1557                    |\n",
    "| 3     | 1.8755            | 7.7516                    |\n",
    "| 4     | 1.8264            | 7.5017                    |\n",
    "| 5     | 1.7349            | 7.2614                    |\n",
    "| 6     | 1.678             | 7.0559                    |\n",
    "| 7     | 1.6937            | 6.8405                    |\n",
    "| 8     | 1.654             | 6.7297                    |\n",
    "| 9     | 1.6365            | 6.7136                    |\n",
    "| 10    | 1.7003            | 6.6279                    |\n",
    "\n",
    "## eval loss\n",
    "\n",
    "对比评估得分\n",
    "\n",
    "| epoch | mindnlp+mindspore  | transformer+torch（4060） |\n",
    "| ----- | ------------------ | ------------------------- |\n",
    "| 1     | 2.1257965564727783 | 6.3235931396484375        |\n",
    "\n",
    "**首先运行以下脚本配置环境**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8361c5cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: http://mirrors.aliyun.com/pypi/simple/\n",
      "Collecting mindnlp\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/0f/a8/5a072852d28a51417b5e330b32e6ae5f26b491ef01a15ba968e77f785e69/mindnlp-0.4.0-py3-none-any.whl (8.4 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.4/8.4 MB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: mindspore>=2.2.14 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindnlp) (2.3.0)\n",
      "Requirement already satisfied: tqdm in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindnlp) (4.65.0)\n",
      "Requirement already satisfied: requests in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindnlp) (2.31.0)\n",
      "Collecting datasets (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/4c/37/22ef7675bef4ffe9577b937ddca2e22791534cbbe11c30714972a91532dc/datasets-3.3.2-py3-none-any.whl (485 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m485.4/485.4 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting evaluate (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/a2/e7/cbca9e2d2590eb9b5aa8f7ebabe1beb1498f9462d2ecede5c9fd9735faaf/evaluate-0.4.3-py3-none-any.whl (84 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.0/84.0 kB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting tokenizers==0.19.1 (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/ba/26/139bd2371228a0e203da7b3e3eddcb02f45b2b7edd91df00e342e4b55e13/tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (3.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting safetensors (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/5d/9a/add3e6fef267658075c5a41573c26d42d80c935cdc992384dfae435feaef/safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (459 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m459.5/459.5 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: sentencepiece in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindnlp) (0.1.99)\n",
      "Requirement already satisfied: regex in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindnlp) (2023.10.3)\n",
      "Collecting addict (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl (3.8 kB)\n",
      "Requirement already satisfied: ml-dtypes in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindnlp) (0.2.0)\n",
      "Collecting pyctcdecode (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/a5/8a/93e2118411ae5e861d4f4ce65578c62e85d0f1d9cb389bd63bd57130604e/pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)\n",
      "Requirement already satisfied: jieba in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindnlp) (0.42.1)\n",
      "Collecting pytest==7.2.0 (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/67/68/a5eb36c3a8540594b6035e6cdae40c1ef1b6a2bfacbecc3d1a544583c078/pytest-7.2.0-py3-none-any.whl (316 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m316.8/316.8 kB\u001b[0m \u001b[31m866.1 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting pillow>=10.0.0 (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/0c/55/f182db572b28bd833b8e806f933f782ceb2df64c40e4d8bd3d4226a46eca/pillow-11.1.0-cp39-cp39-manylinux_2_28_aarch64.whl (4.4 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.4/4.4 MB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: attrs>=19.2.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.1.0)\n",
      "Collecting iniconfig (from pytest==7.2.0->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl (5.9 kB)\n",
      "Requirement already satisfied: packaging in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2)\n",
      "Collecting pluggy<2.0,>=0.12 (from pytest==7.2.0->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl (20 kB)\n",
      "Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.1.3)\n",
      "Collecting tomli>=1.0.0 (from pytest==7.2.0->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl (14 kB)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.16.4 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from tokenizers==0.19.1->mindnlp) (0.18.0)\n",
      "Requirement already satisfied: numpy>=1.17.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (1.23.5)\n",
      "Requirement already satisfied: protobuf>=3.13.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (3.20.3)\n",
      "Requirement already satisfied: asttokens>=2.0.4 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (2.4.1)\n",
      "Requirement already satisfied: scipy>=1.5.4 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (1.11.3)\n",
      "Requirement already satisfied: psutil>=5.6.1 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (5.9.5)\n",
      "Requirement already satisfied: astunparse>=1.6.3 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore>=2.2.14->mindnlp) (1.6.3)\n",
      "Requirement already satisfied: filelock in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from datasets->mindnlp) (3.13.1)\n",
      "Collecting pyarrow>=15.0.0 (from datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/f2/87/4ef05a088b18082cde4950bdfca752dd31effb3ec201b8026e4816d0f3fa/pyarrow-19.0.1-cp39-cp39-manylinux_2_28_aarch64.whl (40.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 MB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting dill<0.3.9,>=0.3.0 (from datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl (116 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: pandas in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from datasets->mindnlp) (2.1.2)\n",
      "Collecting requests (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl (64 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.9/64.9 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting tqdm (from mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl (78 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.5/78.5 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting xxhash (from datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/b4/92/9ac297e3487818f429bcf369c1c6a097edf5b56ed6fc1feff4c1882e87ef/xxhash-3.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (220 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m220.6/220.6 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting multiprocess<0.70.17 (from datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl (133 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m133.4/133.4 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: fsspec<=2024.12.0,>=2023.1.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets->mindnlp) (2023.10.0)\n",
      "Collecting aiohttp (from datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/1d/d5/ab9ad5242c7920e224cbdc1c9bec62a79f75884049ccb86edb64225e4c0f/aiohttp-3.11.13-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting huggingface-hub<1.0,>=0.16.4 (from tokenizers==0.19.1->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl (468 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m468.0/468.0 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: pyyaml>=5.1 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from datasets->mindnlp) (6.0.1)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from requests->mindnlp) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from requests->mindnlp) (3.4)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from requests->mindnlp) (2.0.7)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from requests->mindnlp) (2023.7.22)\n",
      "Collecting pygtrie<3.0,>=2.1 (from pyctcdecode->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/ec/cd/bd196b2cf014afb1009de8b0f05ecd54011d881944e62763f3c1b1e8ef37/pygtrie-2.5.0-py3-none-any.whl (25 kB)\n",
      "Collecting hypothesis<7,>=6.14 (from pyctcdecode->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/3e/15/234573ed76ab2b065c562c72b25ade28ed9d46d0efd347a8599a384521a1/hypothesis-6.127.5-py3-none-any.whl (483 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m483.4/483.4 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: six>=1.12.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from asttokens>=2.0.4->mindspore>=2.2.14->mindnlp) (1.16.0)\n",
      "Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore>=2.2.14->mindnlp) (0.41.3)\n",
      "Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/44/4c/03fb05f56551828ec67ceb3665e5dc51638042d204983a03b0a1541475b6/aiohappyeyeballs-2.4.6-py3-none-any.whl (14 kB)\n",
      "Collecting aiosignal>=1.1.2 (from aiohttp->datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl (7.6 kB)\n",
      "Collecting async-timeout<6.0,>=4.0 (from aiohttp->datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl (6.2 kB)\n",
      "Collecting frozenlist>=1.1.1 (from aiohttp->datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/08/04/e2fddc92135276e07addbc1cf413acffa0c2d848b3e54cacf684e146df49/frozenlist-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (241 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m241.8/241.8 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting multidict<7.0,>=4.5 (from aiohttp->datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/89/87/d451d45aab9e422cb0fb2f7720c31a4c1d3012c740483c37f642eba568fb/multidict-6.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (126 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m126.2/126.2 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting propcache>=0.2.0 (from aiohttp->datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/e6/65/09b1bacf723721e36a84034ff0a4d64d13c7ddb92cfefe9c0b861886f814/propcache-0.3.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (208 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m208.1/208.1 kB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hCollecting yarl<2.0,>=1.17.0 (from aiohttp->datasets->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/0f/4f/438c9fd668954779e48f08c0688ee25e0673380a21bb1e8ccc56de5b55d7/yarl-1.18.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (317 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m317.3/317.3 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: typing-extensions>=3.7.4.3 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from huggingface-hub<1.0,>=0.16.4->tokenizers==0.19.1->mindnlp) (4.8.0)\n",
      "Collecting sortedcontainers<3.0.0,>=2.1.0 (from hypothesis<7,>=6.14->pyctcdecode->mindnlp)\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2023.3.post1)\n",
      "Requirement already satisfied: tzdata>=2022.1 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2023.3)\n",
      "\u001b[33mDEPRECATION: moxing-framework 2.1.16.2ae09d45 has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of moxing-framework or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
      "\u001b[0mInstalling collected packages: sortedcontainers, pygtrie, addict, xxhash, tqdm, tomli, safetensors, requests, pyarrow, propcache, pluggy, pillow, multidict, iniconfig, hypothesis, frozenlist, dill, async-timeout, aiohappyeyeballs, yarl, pytest, pyctcdecode, multiprocess, huggingface-hub, aiosignal, tokenizers, aiohttp, datasets, evaluate, mindnlp\n",
      "  Attempting uninstall: tqdm\n",
      "    Found existing installation: tqdm 4.65.0\n",
      "    Uninstalling tqdm-4.65.0:\n",
      "      Successfully uninstalled tqdm-4.65.0\n",
      "  Attempting uninstall: requests\n",
      "    Found existing installation: requests 2.31.0\n",
      "    Uninstalling requests-2.31.0:\n",
      "      Successfully uninstalled requests-2.31.0\n",
      "  Attempting uninstall: pillow\n",
      "    Found existing installation: Pillow 9.0.1\n",
      "    Uninstalling Pillow-9.0.1:\n",
      "      Successfully uninstalled Pillow-9.0.1\n",
      "  Attempting uninstall: huggingface-hub\n",
      "    Found existing installation: huggingface-hub 0.18.0\n",
      "    Uninstalling huggingface-hub-0.18.0:\n",
      "      Successfully uninstalled huggingface-hub-0.18.0\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "gradio 3.50.2 requires pillow<11.0,>=8.0, but you have pillow 11.1.0 which is incompatible.\n",
      "imageio 2.31.6 requires pillow<10.1.0,>=8.3.2, but you have pillow 11.1.0 which is incompatible.\n",
      "mindtorch 0.3.0 requires tqdm==4.65.0, but you have tqdm 4.67.1 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed addict-2.4.0 aiohappyeyeballs-2.4.6 aiohttp-3.11.13 aiosignal-1.3.2 async-timeout-5.0.1 datasets-3.3.2 dill-0.3.8 evaluate-0.4.3 frozenlist-1.5.0 huggingface-hub-0.29.1 hypothesis-6.127.5 iniconfig-2.0.0 mindnlp-0.4.0 multidict-6.1.0 multiprocess-0.70.16 pillow-11.1.0 pluggy-1.5.0 propcache-0.3.0 pyarrow-19.0.1 pyctcdecode-0.5.0 pygtrie-2.5.0 pytest-7.2.0 requests-2.32.3 safetensors-0.5.3 sortedcontainers-2.4.0 tokenizers-0.19.1 tomli-2.2.1 tqdm-4.67.1 xxhash-3.5.0 yarl-1.18.3\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: http://mirrors.aliyun.com/pypi/simple/\n",
      "Collecting mindspore==2.4\n",
      "  Downloading http://mirrors.aliyun.com/pypi/packages/1b/e4/87dc1ae146f0715fa0ae9c04aab4cb44d07d971cb643c9460d0050d6a031/mindspore-2.4.0-cp39-none-any.whl (333.7 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m333.7/333.7 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:06\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: numpy<2.0.0,>=1.20.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (1.23.5)\n",
      "Requirement already satisfied: protobuf>=3.13.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (3.20.3)\n",
      "Requirement already satisfied: asttokens>=2.0.4 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (2.4.1)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (11.1.0)\n",
      "Requirement already satisfied: scipy>=1.5.4 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (1.11.3)\n",
      "Requirement already satisfied: packaging>=20.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (23.2)\n",
      "Requirement already satisfied: psutil>=5.6.1 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (5.9.5)\n",
      "Requirement already satisfied: astunparse>=1.6.3 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (1.6.3)\n",
      "Requirement already satisfied: safetensors>=0.4.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from mindspore==2.4) (0.5.3)\n",
      "Requirement already satisfied: six>=1.12.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from asttokens>=2.0.4->mindspore==2.4) (1.16.0)\n",
      "Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore==2.4) (0.41.3)\n",
      "\u001b[33mDEPRECATION: moxing-framework 2.1.16.2ae09d45 has a non-standard version number. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of moxing-framework or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063\u001b[0m\u001b[33m\n",
      "\u001b[0mInstalling collected packages: mindspore\n",
      "  Attempting uninstall: mindspore\n",
      "    Found existing installation: mindspore 2.3.0\n",
      "    Uninstalling mindspore-2.3.0:\n",
      "      Successfully uninstalled mindspore-2.3.0\n",
      "Successfully installed mindspore-2.4.0\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "# 在Ascend910B1环境需要额外安装以下\n",
    "# !pip install mindnlp\n",
    "# !pip install mindspore==2.4\n",
    "# !export LD_PRELOAD=$LD_PRELOAD:/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/torch.libs/libgomp-74ff64e9.so.1.0.0\n",
    "# !yum install libsndfile"
   ]
  },
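  {
   "cell_type": "markdown",
   "id": "b7c1e2a0",
   "metadata": {},
   "source": [
    "The install log above shows pip replacing an existing MindSpore 2.3.0 with 2.4.0, so it may be worth confirming which versions are installed before going further. The cell below is a minimal sketch that uses only the standard library; the helper name `installed_version` is an assumption made for illustration, and the package names are the distribution names used in the pip commands above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d9f3b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from importlib import metadata\n",
    "\n",
    "\n",
    "def installed_version(package):\n",
    "    # Return the installed version string, or None if the package is missing.\n",
    "    try:\n",
    "        return metadata.version(package)\n",
    "    except metadata.PackageNotFoundError:\n",
    "        return None\n",
    "\n",
    "\n",
    "# Report the versions of the two packages installed by the setup cell.\n",
    "for name in ('mindspore', 'mindnlp'):\n",
    "    print(name, installed_version(name))"
   ]
  },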
  {
   "cell_type": "markdown",
   "id": "d780a67a",
   "metadata": {},
   "source": [
    "## 导入库\n",
    "注意这里曾经导入了多个Tokenizer进行过测试。\n",
    "要设置mindspore工作环境为Ascend。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d127981e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] GE_ADPT(37,ffff8709e010,python):2025-03-04-11:16:41.325.592 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclmdlBundleGetModelId failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclmdlBundleGetModelId\n",
      "[WARNING] GE_ADPT(37,ffff8709e010,python):2025-03-04-11:16:41.325.674 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclmdlBundleLoadFromMem failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclmdlBundleLoadFromMem\n",
      "[WARNING] GE_ADPT(37,ffff8709e010,python):2025-03-04-11:16:41.325.715 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclmdlBundleUnload failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclmdlBundleUnload\n",
      "[WARNING] GE_ADPT(37,ffff8709e010,python):2025-03-04-11:16:41.325.909 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclrtGetMemUceInfo failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclrtGetMemUceInfo\n",
      "[WARNING] GE_ADPT(37,ffff8709e010,python):2025-03-04-11:16:41.325.926 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclrtDeviceTaskAbort failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclrtDeviceTaskAbort\n",
      "[WARNING] GE_ADPT(37,ffff8709e010,python):2025-03-04-11:16:41.325.941 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclrtMemUceRepair failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclrtMemUceRepair\n",
      "[WARNING] GE_ADPT(37,ffff8709e010,python):2025-03-04-11:16:41.327.779 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol acltdtCleanChannel failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libacl_tdt_channel.so: undefined symbol: acltdtCleanChannel\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:41.550.830 [mindspore/run_check/_check_version.py:327] MindSpore version 2.4.0 and Ascend AI software package (Ascend Data Center Solution)version 7.2 does not match, the version of software package expect one of ['7.3', '7.5']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:41.554.596 [mindspore/run_check/_check_version.py:396] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:44.300.46 [mindspore/run_check/_check_version.py:345] MindSpore version 2.4.0 and \"te\" wheel package version 7.2 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:44.341.43 [mindspore/run_check/_check_version.py:352] MindSpore version 2.4.0 and \"hccl\" wheel package version 7.2 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:44.358.17 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:45.385.67 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:46.419.87 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 1\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache /tmp/jieba.cache\n",
      "Loading model cost 1.375 seconds.\n",
      "Prefix dict has been built successfully.\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:55.820.621 [mindspore/run_check/_check_version.py:327] MindSpore version 2.4.0 and Ascend AI software package (Ascend Data Center Solution)version 7.2 does not match, the version of software package expect one of ['7.3', '7.5']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:55.827.619 [mindspore/run_check/_check_version.py:396] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:55.828.586 [mindspore/run_check/_check_version.py:345] MindSpore version 2.4.0 and \"te\" wheel package version 7.2 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:55.829.144 [mindspore/run_check/_check_version.py:352] MindSpore version 2.4.0 and \"hccl\" wheel package version 7.2 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:55.829.808 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:56.831.621 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:57.834.664 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 1\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:58.839.664 [mindspore/run_check/_check_version.py:327] MindSpore version 2.4.0 and Ascend AI software package (Ascend Data Center Solution)version 7.2 does not match, the version of software package expect one of ['7.3', '7.5']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:58.843.964 [mindspore/run_check/_check_version.py:396] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:58.845.048 [mindspore/run_check/_check_version.py:345] MindSpore version 2.4.0 and \"te\" wheel package version 7.2 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:58.845.711 [mindspore/run_check/_check_version.py:352] MindSpore version 2.4.0 and \"hccl\" wheel package version 7.2 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:58.846.365 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:16:59.848.213 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(37:281472947314704,MainProcess):2025-03-04-11:17:00.851.249 [mindspore/run_check/_check_version.py:366] Please pay attention to the above warning, countdown: 1\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from mindnlp.transformers import (\n",
    "    BigBirdPegasusForCausalLM,\n",
    "    PegasusTokenizer,\n",
    "    AutoTokenizer\n",
    ")\n",
    "from datasets import load_dataset, DatasetDict\n",
    "from mindspore.dataset import GeneratorDataset\n",
    "from mindnlp.engine import Trainer, TrainingArguments\n",
    "import mindspore as ms\n",
    "# Set the execution mode and the target device\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE, device_target=\"Ascend\")"
   ]
  },
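  {
   "cell_type": "markdown",
   "id": "dbcec2d3",
   "metadata": {},
   "source": [
    "## Preparing the dataset\n",
    "To make repeated fine-tuning runs fast, the processed dataset is saved to local disk and reloaded on later runs. Note that this notebook uses BigBirdPegasusForCausalLM, a causal language model, so each sample is flattened into a single prompt string. Only the first 50 samples are used: 40 for training and 10 for evaluation."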
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "caec8504",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path where the processed dataset is saved\n",
    "dataset_path = \"./processed_dataset\"\n",
    "# Reuse the processed dataset if it already exists on disk\n",
    "if os.path.exists(dataset_path):\n",
    "    dataset = DatasetDict.load_from_disk(dataset_path)\n",
    "    train_dataset = dataset[\"train\"]\n",
    "    eval_dataset = dataset[\"eval\"]\n",
    "else:\n",
    "    # Load and process the dataset\n",
    "    dataset = load_dataset(\"databricks/databricks-dolly-15k\")\n",
    "    print(dataset)\n",
    "\n",
    "    def format_prompt(sample):\n",
    "        instruction = f\"### Instruction\\n{sample['instruction']}\"\n",
    "        context = f\"### Context\\n{sample['context']}\" if len(sample[\"context\"]) > 0 else None\n",
    "        response = f\"### Answer\\n{sample['response']}\"\n",
    "        prompt = \"\\n\\n\".join([i for i in [instruction, context, response] if i is not None])\n",
    "        sample[\"prompt\"] = prompt\n",
    "        return sample\n",
    "\n",
    "    dataset = dataset.map(format_prompt)\n",
    "    dataset = dataset.remove_columns(['instruction', 'context', 'response', 'category'])\n",
    "    train_dataset = dataset[\"train\"].select(range(0, 40))\n",
    "    eval_dataset = dataset[\"train\"].select(range(40, 50))\n",
    "    # print(train_dataset)\n",
    "    # print(eval_dataset)\n",
    "    # print(train_dataset[0])\n",
    "    # Save the processed dataset\n",
    "    dataset = DatasetDict({\"train\": train_dataset, \"eval\": eval_dataset})\n",
    "    dataset.save_to_disk(dataset_path)"
   ]
  },
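  {
   "cell_type": "markdown",
   "id": "3f9a1c7e",
   "metadata": {},
   "source": [
    "To make the prompt layout built by `format_prompt` concrete, here is a minimal standalone sketch of the same joining logic applied to a made-up sample. The toy record is invented for illustration and is not taken from databricks-dolly-15k, and this version returns the string directly instead of writing it into a `prompt` column.\n",
    "\n",
    "```python\n",
    "def format_prompt(sample):\n",
    "    # Join the instruction, the optional context, and the response into one string.\n",
    "    parts = ['### Instruction\\n' + sample['instruction']]\n",
    "    if sample['context']:\n",
    "        parts.append('### Context\\n' + sample['context'])\n",
    "    parts.append('### Answer\\n' + sample['response'])\n",
    "    return '\\n\\n'.join(parts)\n",
    "\n",
    "\n",
    "# Made-up sample with an empty context, so the Context block is skipped.\n",
    "toy = {'instruction': 'Name a primary color.', 'context': '', 'response': 'Red.'}\n",
    "prompt = format_prompt(toy)\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "When `context` is empty the `### Context` block is dropped, so the prompt contains only the instruction and answer sections."
   ]
  },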
  {
   "cell_type": "markdown",
   "id": "4e401840",
   "metadata": {},
   "source": [
    "## Loading the model\n",
    "MindNLP does not appear to provide a class such as BigBirdPegasusTokenizer, so AutoTokenizer is used here. One of the MindNLP examples uses PegasusTokenizer instead, and both were tried.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a267c7fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n",
      "BigBirdPegasusForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
      "  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
      "  - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] DEVICE(37,fffd60ebb0e0,python):2025-03-04-11:17:17.714.431 [mindspore/ccsrc/transform/acl_ir/op_api_convert.h:114] GetOpApiFunc] Dlsym aclSetAclOpExecutorRepeatable failed!\n",
      "[WARNING] KERNEL(37,fffd60ebb0e0,python):2025-03-04-11:17:17.714.567 [mindspore/ccsrc/transform/acl_ir/op_api_cache.h:54] SetExecutorRepeatable] The aclSetAclOpExecutorRepeatable is unavailable, which results in aclnn cache miss.\n",
      "[WARNING] DEVICE(37,fffd5abce0e0,python):2025-03-04-11:17:17.732.921 [mindspore/ccsrc/transform/acl_ir/op_api_convert.h:114] GetOpApiFunc] Dlsym aclDestroyAclOpExecutor failed!\n"
     ]
    }
   ],
   "source": [
    "model_name = \"google/bigbird-pegasus-large-arxiv\"\n",
    "tokenizer_name = \"google/bigbird-pegasus-large-arxiv\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n",
    "# tokenizer = PegasusTokenizer.from_pretrained(tokenizer_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token \n",
    "model = BigBirdPegasusForCausalLM.from_pretrained(model_name)"
   ]
  },
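  {
   "cell_type": "markdown",
   "id": "bbda48b5",
   "metadata": {},
   "source": [
    "## Converting the dataset to the training format\n",
    "No equivalent of the DataCollatorForLanguageModeling utility from transformers was found in MindNLP, so padding and truncation are implemented by hand here.\n",
    "One processed sample is printed so that it can be compared with the torch version, to confirm that both runs see the same data."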
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe44b259",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_dataset: <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset object at 0xffff404b6430>\n",
      "eval_dataset: <mindspore.dataset.engine.datasets_user_defined.GeneratorDataset object at 0xffff45782430>\n",
      "{'input_ids': Tensor(shape=[256], dtype=Int64, value= [  110, 63444, 26323,   463,   117,   114,   110, 84040,  5551, 41676,   152,   110, 63444, 30058,   222, 22600,   108,   114,   110, 84040,  5551, 41676,   117,   142, \n",
      "  8091, 41676,   120,   117,   263,   112, 37525,   523,   108,   120,   117,   108,   112,  1910,   523,   190,   203, 31059,  2274,   143,   544,  1613,   113,   109, \n",
      " 12091,   250, 10008, 44069,   143, 10209,   116,   158,   113,   523,   138,   129, 53136,   141,   109, 41676,   134,   291, 10269,   107,   182,   117,   114,   711, \n",
      "   113,   109, 41676,  1001,   131,   116,  4224,   113, 67669,  7775,   122, 30671,   143, 84040,  2928,   250, 10879,   108,   895, 44069,   143,  6388,   158, 11213, \n",
      "   114,  1934, 28593,   197,  6306, 44069,   143, 11753,   250,   139, 31757,   113,   695,   523,   190,  1613,   141,   114, 41676,  1358,  6381, 15121, 12455,   112, \n",
      " 10796,   120,   695,   523, 13333,   113,   114,  3173,   113,   291,  1613,   107,   110, 63444, 13641,   202,   110, 84040,  5551, 41676,   117,   142,  8091, 41676, \n",
      "   120, 37525,   116,   109,   523,   131,   116,   291, 44069,   134,   291, 10269,   107,   434,   695,   523,   117, 66437,   224,   114,   110, 84040,  5551, 41676, \n",
      "   126,   138,  1910,   190,   109,   291,  1613,   113,   109, 12091,   107,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1, \n",
      "     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1, \n",
      "     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1, \n",
      "     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1]), 'attention_mask': Tensor(shape=[256], dtype=Int64, value= [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, \n",
      " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \n",
      " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \n",
      " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, \n",
      " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'labels': Tensor(shape=[256], dtype=Int64, value= [  110, 63444, 26323,   463,   117,   114,   110, 84040,  5551, 41676,   152,   110, 63444, 30058,   222, 22600,   108,   114,   110, 84040,  5551, 41676,   117,   142, \n",
      "  8091, 41676,   120,   117,   263,   112, 37525,   523,   108,   120,   117,   108,   112,  1910,   523,   190,   203, 31059,  2274,   143,   544,  1613,   113,   109, \n",
      " 12091,   250, 10008, 44069,   143, 10209,   116,   158,   113,   523,   138,   129, 53136,   141,   109, 41676,   134,   291, 10269,   107,   182,   117,   114,   711, \n",
      "   113,   109, 41676,  1001,   131,   116,  4224,   113, 67669,  7775,   122, 30671,   143, 84040,  2928,   250, 10879,   108,   895, 44069,   143,  6388,   158, 11213, \n",
      "   114,  1934, 28593,   197,  6306, 44069,   143, 11753,   250,   139, 31757,   113,   695,   523,   190,  1613,   141,   114, 41676,  1358,  6381, 15121, 12455,   112, \n",
      " 10796,   120,   695,   523, 13333,   113,   114,  3173,   113,   291,  1613,   107,   110, 63444, 13641,   202,   110, 84040,  5551, 41676,   117,   142,  8091, 41676, \n",
      "   120, 37525,   116,   109,   523,   131,   116,   291, 44069,   134,   291, 10269,   107,   434,   695,   523,   117, 66437,   224,   114,   110, 84040,  5551, 41676, \n",
      "   126,   138,  1910,   190,   109,   291,  1613,   113,   109, 12091,   107,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1, \n",
      "     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1, \n",
      "     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1, \n",
      "     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1,     1])}\n"
     ]
    }
   ],
   "source": [
    "class TextDataset:\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "    # Tokenize one prompt, padding or truncating it to 256 tokens\n",
    "    def __getitem__(self, index):\n",
    "        index = int(index)\n",
    "        text = self.data[index][\"prompt\"]\n",
    "        inputs = tokenizer(text, padding='max_length', max_length=256, truncation=True)\n",
    "        return (\n",
    "            inputs[\"input_ids\"], \n",
    "            inputs[\"attention_mask\"],\n",
    "            inputs[\"input_ids\"]  # labels are a copy of input_ids\n",
    "        )\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "train_dataset = GeneratorDataset(\n",
    "    TextDataset(train_dataset),\n",
    "    column_names=[\"input_ids\", \"attention_mask\", \"labels\"],  # includes the labels column\n",
    "    shuffle=True\n",
    ")\n",
    "eval_dataset = GeneratorDataset(\n",
    "    TextDataset(eval_dataset),\n",
    "    column_names=[\"input_ids\", \"attention_mask\", \"labels\"],  # includes the labels column\n",
    "    shuffle=False\n",
    ")\n",
    "print(\"train_dataset:\", train_dataset)\n",
    "print(\"eval_dataset:\", eval_dataset)\n",
    "for data in train_dataset.create_dict_iterator():\n",
    "    print(data)\n",
    "    break"
   ]
  },
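  {
   "cell_type": "markdown",
   "id": "7b2d4e90",
   "metadata": {},
   "source": [
    "The printed sample shows that `labels` is a straight copy of `input_ids`, so the padded positions (token id 1, attention mask 0) are scored by the loss as well. A common convention in Hugging Face style causal LM training is to set those label positions to -100 so that cross-entropy skips them. The sketch below shows that masking step on plain NumPy arrays. `mask_pad_labels` and `IGNORE_INDEX` are hypothetical names, it is an assumption (not verified here) that the MindNLP loss ignores -100, and this step was not used for the results in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "IGNORE_INDEX = -100  # assumed ignore value, following the Hugging Face convention\n",
    "\n",
    "\n",
    "def mask_pad_labels(input_ids, attention_mask, ignore_index=IGNORE_INDEX):\n",
    "    # Copy input_ids to labels and replace padded positions with ignore_index.\n",
    "    input_ids = np.asarray(input_ids, dtype=np.int64)\n",
    "    attention_mask = np.asarray(attention_mask, dtype=np.int64)\n",
    "    # Keep the token id where the mask is 1, otherwise use ignore_index.\n",
    "    return np.where(attention_mask == 1, input_ids, ignore_index)\n",
    "\n",
    "\n",
    "# Toy example: three real tokens followed by two pad tokens (id 1).\n",
    "labels = mask_pad_labels([110, 63444, 107, 1, 1], [1, 1, 1, 0, 0])\n",
    "print(labels.tolist())\n",
    "```\n",
    "\n",
    "Masking by `attention_mask` rather than by token id matters here because `pad_token` is set to `eos_token`: a genuine end-of-sequence token inside the attended span keeps its label, and only the padding is skipped."
   ]
  },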
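  {
   "cell_type": "markdown",
   "id": "8e3ddebb",
   "metadata": {},
   "source": [
    "## Configuring the trainer and training\n",
    "The parameters here must match those of the torch training run. The training loss is logged after each epoch so that the two runs can be compared."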
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d3fe864b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 1/100 [00:28<47:21, 28.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 10/100 [00:38<01:38,  1.10s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 2.0958, 'learning_rate': 4.5e-05, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      " 67%|██████▋   | 2/3 [00:01<00:00,  1.63it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/\r"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                \n",
      " 10%|█         | 10/100 [00:43<01:38,  1.10s/it]\n",
      "100%|██████████| 3/3 [00:04<00:00,  1.63it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.592344045639038, 'eval_runtime': 4.9288, 'eval_samples_per_second': 0.609, 'eval_steps_per_second': 0.203, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 20/100 [00:50<01:04,  1.24it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.969, 'learning_rate': 4e-05, 'epoch': 2.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 20%|██        | 20/100 [00:50<01:04,  1.24it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 19.53it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.486072063446045, 'eval_runtime': 0.2738, 'eval_samples_per_second': 10.956, 'eval_steps_per_second': 3.652, 'epoch': 2.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 30/100 [00:57<00:46,  1.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.8755, 'learning_rate': 3.5e-05, 'epoch': 3.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 30%|███       | 30/100 [00:57<00:46,  1.50it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 22.78it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.367415189743042, 'eval_runtime': 0.2442, 'eval_samples_per_second': 12.283, 'eval_steps_per_second': 4.094, 'epoch': 3.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 40/100 [01:04<00:39,  1.54it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.8264, 'learning_rate': 3e-05, 'epoch': 4.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 40%|████      | 40/100 [01:04<00:39,  1.54it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 24.96it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.3535046577453613, 'eval_runtime': 0.241, 'eval_samples_per_second': 12.45, 'eval_steps_per_second': 4.15, 'epoch': 4.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 50/100 [01:11<00:34,  1.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.7349, 'learning_rate': 2.5e-05, 'epoch': 5.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 50%|█████     | 50/100 [01:11<00:34,  1.45it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 22.24it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.2972629070281982, 'eval_runtime': 0.2457, 'eval_samples_per_second': 12.21, 'eval_steps_per_second': 4.07, 'epoch': 5.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 60/100 [01:18<00:24,  1.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.678, 'learning_rate': 2e-05, 'epoch': 6.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 60%|██████    | 60/100 [01:18<00:24,  1.61it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 24.91it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.195664882659912, 'eval_runtime': 0.2324, 'eval_samples_per_second': 12.91, 'eval_steps_per_second': 4.303, 'epoch': 6.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 70/100 [01:25<00:20,  1.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.6937, 'learning_rate': 1.5e-05, 'epoch': 7.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 70%|███████   | 70/100 [01:25<00:20,  1.44it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 21.99it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.1624794006347656, 'eval_runtime': 0.2587, 'eval_samples_per_second': 11.596, 'eval_steps_per_second': 3.865, 'epoch': 7.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 80/100 [01:32<00:13,  1.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.654, 'learning_rate': 1e-05, 'epoch': 8.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 80%|████████  | 80/100 [01:32<00:13,  1.48it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 23.14it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.159714460372925, 'eval_runtime': 0.2363, 'eval_samples_per_second': 12.696, 'eval_steps_per_second': 4.232, 'epoch': 8.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 90/100 [01:39<00:06,  1.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.6365, 'learning_rate': 5e-06, 'epoch': 9.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                \n",
      " 90%|█████████ | 90/100 [01:39<00:06,  1.51it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 22.68it/s]\u001b[A\n",
      "                                             \u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.1347262859344482, 'eval_runtime': 0.2604, 'eval_samples_per_second': 11.523, 'eval_steps_per_second': 3.841, 'epoch': 9.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:46<00:00,  1.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.7003, 'learning_rate': 0.0, 'epoch': 10.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]\u001b[A\n",
      "                                                 \n",
      "100%|██████████| 100/100 [01:46<00:00,  1.52it/s]\n",
      "100%|██████████| 3/3 [00:00<00:00, 21.63it/s]\u001b[A\n",
      "100%|██████████| 100/100 [01:46<00:00,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 2.1257965564727783, 'eval_runtime': 0.2557, 'eval_samples_per_second': 11.733, 'eval_steps_per_second': 3.911, 'epoch': 10.0}\n",
      "{'train_runtime': 106.4446, 'train_samples_per_second': 3.758, 'train_steps_per_second': 0.939, 'train_loss': 1.7863994789123536, 'epoch': 10.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=100, training_loss=1.7863994789123536, metrics={'train_runtime': 106.4446, 'train_samples_per_second': 3.758, 'train_steps_per_second': 0.939, 'train_loss': 1.7863994789123536, 'epoch': 10.0})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "EPOCHS = 10\n",
    "BATCH_SIZE = 4\n",
    "# Define the training arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./MindsporeBigBirdFinetune',\n",
    "    overwrite_output_dir=True,\n",
    "    num_train_epochs=EPOCHS,\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    per_device_eval_batch_size=BATCH_SIZE,\n",
    "    \n",
    "    save_steps=500,                  # Save checkpoint every 500 steps\n",
    "    save_total_limit=2,              # Keep only the last 2 checkpoints\n",
    "    logging_dir=\"./logs\",            # Directory for logs\n",
    "    logging_steps=100,               # Not used when logging_strategy is \"epoch\"\n",
    "    logging_strategy=\"epoch\",        # Log the training loss once per epoch\n",
    "    evaluation_strategy=\"epoch\",     # Evaluate once per epoch\n",
    "    eval_steps=500,                  # Not used when evaluation_strategy is \"epoch\"\n",
    "    learning_rate=5e-5,\n",
    "    weight_decay=0.01,               # Weight decay\n",
    ")\n",
    "\n",
    "# Create the trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=None\n",
    ")\n",
    "trainer.train()"
   ]
  },
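  {
   "cell_type": "markdown",
   "id": "c41e8a56",
   "metadata": {},
   "source": [
    "The `learning_rate` values in the log fall from 4.5e-05 after epoch 1 to 0.0 after epoch 10. That is what a linear decay from 5e-5 to zero over 100 steps with no warmup would produce (40 samples / batch size 4 = 10 steps per epoch, 10 epochs). The sketch below reproduces those numbers. `linear_decay_lr` is a hypothetical helper for illustration, not part of MindNLP, and the assumption that the Trainer uses exactly this schedule is inferred from the log only.\n",
    "\n",
    "```python\n",
    "BASE_LR = 5e-5\n",
    "EPOCHS = 10\n",
    "STEPS_PER_EPOCH = 40 // 4  # 40 training samples, batch size 4\n",
    "TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH\n",
    "\n",
    "\n",
    "def linear_decay_lr(step, total_steps=TOTAL_STEPS, base_lr=BASE_LR):\n",
    "    # Linear decay from base_lr at step 0 to 0 at total_steps, no warmup.\n",
    "    return base_lr * max(0.0, 1.0 - step / total_steps)\n",
    "\n",
    "\n",
    "# Learning rate at the end of each epoch, to compare with the training log.\n",
    "for epoch in range(1, EPOCHS + 1):\n",
    "    print(epoch, f'{linear_decay_lr(epoch * STEPS_PER_EPOCH):.2e}')\n",
    "```\n",
    "\n",
    "Because the rate reaches zero at the final step, the last epoch makes only very small updates."
   ]
  },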
  {
   "cell_type": "markdown",
   "id": "c8a575ad",
   "metadata": {},
   "source": [
    "## Viewing the evaluation results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5c0833ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:00<00:00, 15.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation results: {'eval_loss': 2.1257965564727783, 'eval_runtime': 0.3007, 'eval_samples_per_second': 9.977, 'eval_steps_per_second': 3.326, 'epoch': 10.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "eval_results = trainer.evaluate()\n",
    "print(f\"Evaluation results: {eval_results}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a294ba38",
   "metadata": {},
   "source": [
    "## Saving the fine-tuned model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0c5b2db5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some non-default generation parameters are set in the model config. These should go into a GenerationConfig file instead.\n",
      "Non-default generation parameters: {'max_length': 256, 'num_beams': 5, 'length_penalty': 0.8}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('./mindNLPTokenizerBigbirdPegasusFinetune/tokenizer_config.json',\n",
       " './mindNLPTokenizerBigbirdPegasusFinetune/special_tokens_map.json',\n",
       " './mindNLPTokenizerBigbirdPegasusFinetune/spiece.model',\n",
       " './mindNLPTokenizerBigbirdPegasusFinetune/added_tokens.json',\n",
       " './mindNLPTokenizerBigbirdPegasusFinetune/tokenizer.json')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.save_pretrained(\"./mindNLPModelBigbirdPegasusFinetune\")\n",
    "tokenizer.save_pretrained(\"./mindNLPTokenizerBigbirdPegasusFinetune\")"
   ]
  },
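  {
   "cell_type": "markdown",
   "id": "20472ce6",
   "metadata": {},
   "source": [
    "## Testing the fine-tuned model\n",
    "Although the loss keeps decreasing and is lower than in the torch run, both runs are only brief fine-tuning runs, so the language model's actual output quality is poor and the output below is not meaningful. Note also that the cell below takes the argmax of the logits from a single forward pass, which yields one next-token prediction per input position rather than an autoregressively generated continuation."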
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8e3fab68",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "in,, have a but\n"
     ]
    }
   ],
   "source": [
    "fine_tuned_model = BigBirdPegasusForCausalLM.from_pretrained(\"./mindNLPModelBigbirdPegasusFinetune\")\n",
    "fine_tuned_tokenizer = PegasusTokenizer.from_pretrained(\"./mindNLPTokenizerBigbirdPegasusFinetune\")\n",
    "inputs = \"Hello, my dog is cute\"\n",
    "input_tokens = fine_tuned_tokenizer(inputs, return_tensors=\"ms\")\n",
    "outputs = fine_tuned_model(**input_tokens)\n",
    "logits = outputs.logits\n",
    "# Take the argmax to get the predicted token IDs\n",
    "from mindspore import ops\n",
    "predicted_token_ids = ops.argmax(logits, dim=-1)  # argmax over the last (vocab_size) dimension\n",
    "# Decode the predicted token IDs into text\n",
    "generated_text = fine_tuned_tokenizer.decode(predicted_token_ids[0].asnumpy().tolist(), skip_special_tokens=True)\n",
    "print(generated_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
